package com.example.springboot.cros.filter;


import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.Order;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.mvc.condition.PatternsRequestCondition;
import org.springframework.web.servlet.mvc.method.RequestMappingInfo;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;

import javax.servlet.*;
import javax.servlet.FilterConfig;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.io.Reader;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/**
 * Servlet filter that validates incoming API requests.
 *
 * <ol>
 *   <li>At startup, collects every URL pattern registered with Spring MVC's
 *       {@link RequestMappingHandlerMapping}; only requests whose servlet path is one of those
 *       patterns are inspected.</li>
 *   <li>{@link HttpServletRequest#getInputStream()} can only be consumed once, so a matching
 *       request is wrapped in {@code CusRequestWrapper} before its body is read, and the wrapper
 *       (not the original request) is passed down the filter chain.</li>
 * </ol>
 *
 * <p>NOTE(review): the lookup is an exact string match against the registered patterns, so a
 * templated mapping such as {@code /user/{id}} never matches a concrete path like
 * {@code /user/42} and such requests are not validated. Confirm this is intended.
 */
@Order(1)
@Slf4j
@WebFilter(filterName = "signValidateFilter", urlPatterns = "/*")
public class SignValidateFilter implements Filter {

    /** URL patterns of all registered controller handler methods; populated once in {@link #init}. */
    private final Set<String> validUrl = new TreeSet<>();

    /**
     * Collects the URL patterns of every handler method known to Spring MVC.
     *
     * @param filterConfig container-supplied filter configuration
     * @throws ServletException if the root Spring web application context is missing from the
     *     servlet context, or if it failed to start
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        ServletContext context = filterConfig.getServletContext();
        Object rootAttr = context.getAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE);
        // Spring's ContextLoader stores the startup exception under this key when the root
        // context fails to initialize, so the attribute is not guaranteed to be a context.
        if (rootAttr instanceof Throwable) {
            throw new ServletException("Root WebApplicationContext failed to start", (Throwable) rootAttr);
        }
        if (!(rootAttr instanceof ApplicationContext)) {
            throw new ServletException(
                    "No root WebApplicationContext found in ServletContext; cannot initialize signValidateFilter");
        }
        ApplicationContext ac = (ApplicationContext) rootAttr;

        // Register the URL patterns of all valid controller mappings.
        RequestMappingHandlerMapping mapping = ac.getBean(RequestMappingHandlerMapping.class);
        Map<RequestMappingInfo, HandlerMethod> handlerMethods = mapping.getHandlerMethods();
        for (RequestMappingInfo info : handlerMethods.keySet()) {
            PatternsRequestCondition patterns = info.getPatternsCondition();
            // Null on Spring 5.3+ when PathPatternParser-based matching is enabled.
            if (patterns != null) {
                validUrl.addAll(patterns.getPatterns());
            }
        }
    }

    /**
     * Reads and validates the body of requests aimed at a known controller URL; every other
     * request (including non-HTTP requests) is passed through untouched.
     *
     * @param request the incoming request
     * @param response the outgoing response
     * @param chain the remaining filter chain
     * @throws IOException if reading the request body fails, or propagated from the chain
     * @throws ServletException propagated from the chain
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        if (!(request instanceof HttpServletRequest)) {
            chain.doFilter(request, response);
            return;
        }
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;
        if (!validUrl.contains(httpServletRequest.getServletPath())) {
            chain.doFilter(request, response);
            return;
        }

        // Wrap before reading: the raw body stream is single-use and downstream handlers still
        // need it, so the body is read through the wrapper and the wrapper is forwarded.
        ServletRequest requestWrapper = new CusRequestWrapper(httpServletRequest);
        Reader reader = requestWrapper.getReader();
        String reqJson = IOUtils.toString(reader);

        // Validate that the request is legitimate.
        log.info("***[1]checkValidity**验证请求合法***");
        checkValidity(reqJson);

        chain.doFilter(requestWrapper, response);
    }

    /**
     * Checks whether the request parameters represent a legitimate request.
     *
     * <p>Currently a placeholder that accepts every request.
     *
     * @param param the raw request body read in {@link #doFilter}
     */
    private void checkValidity(Object param) {
        // TODO: implement signature / parameter validation.
    }


}
